// Copyright 2022 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_VECTORFETCHUNIT_H_
#define THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_VECTORFETCHUNIT_H_

#include <mc_connections.h>
#include <systemc.h>

#include "src/AccelTypes.h"

// FetchUnit: for each VectorParams word received on `paramsIn`, streams
// integer fetch addresses out on `vectorFetchAddressRequest`.
//
// Template parameter NROWS is not referenced anywhere in this module's body.
template <int NROWS>
SC_MODULE(FetchUnit) {
  sc_in<bool> CCS_INIT_S1(clk);
  sc_in<bool> CCS_INIT_S1(rstn);

  // One VectorParams is popped per vector operation.
  Connections::In<VectorParams> CCS_INIT_S1(paramsIn);
  // Fetch addresses (already offset by VECTOR_OFFSET) for the operation.
  Connections::Out<int> CCS_INIT_S1(vectorFetchAddressRequest);

  SC_CTOR(FetchUnit) {
    // Clocked thread on the rising edge of clk; rstn is an active-low
    // asynchronous reset (async_reset_signal_is(..., false)).
    SC_THREAD(fetch_vector);
    sensitive << clk.pos();
    async_reset_signal_is(rstn, false);
  }

  // Main thread body.
  //
  // After reset, repeatedly pops a VectorParams. Operations that are neither
  // SOFTMAX nor LAYER_NORM are consumed without issuing any fetch. For those
  // two operations, every row of the tensor is walked three times in a row
  // (both currently require three passes over the data), pushing the
  // row-major address `row * numCols + col`, plus VECTOR_OFFSET, for each
  // column.
  void fetch_vector() {
    paramsIn.Reset();
    vectorFetchAddressRequest.Reset();

    wait();

    while (true) {
      VectorParams cfg = paramsIn.Pop();

      // Guard: only softmax / layer-norm generate fetch requests here.
      if (!(cfg.SOFTMAX || cfg.LAYER_NORM)) {
        continue;
      }

      // Tensor extents taken from loop level 1 at the y/x loop indices.
      // NOTE(review): presumably the outer-level loop bounds — confirm
      // against the VectorParams producer.
      int numRows = cfg.loops[1][cfg.yLoopIndex[1]];
      int numCols = cfg.loops[1][cfg.xLoopIndex[1]];

      for (int row = 0; row < numRows; row++) {
        // Start of this row in row-major order; invariant across passes.
        int rowStart = row * numCols;

        for (int pass = 0; pass < 3; pass++) {
          for (int col = 0; col < numCols; col++) {
            int address = rowStart + col;
            vectorFetchAddressRequest.Push(address + cfg.VECTOR_OFFSET);
          }
        }
      }
    }
  }
};

#endif  // THIRD_PARTY_DNN_ACCELERATOR_HLS_SRC_VECTORFETCHUNIT_H_
